<!-- Pre and post-processing functions from: https://github.com/AndreyGermanov/yolov8_onnx_javascript -->
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>YOLOv8 tinygrad WebGL</title>
    <script src="./net.js"></script>
    <style>
       body {
            text-align: center;
            font-family: Arial, sans-serif;
            margin: 0;
            padding: 0;
            overflow: hidden;
        }

        .video-container {
            position: relative;
            width: 100%;
            margin: 0 auto;
        }

        #video, #canvas {
            position: absolute;
            top: 0;
            left: 0;
            width: 100%;
            height: auto;
        }

        #canvas {
            background: transparent;
        }

        h1 {
            margin-top: 20px;
        }
    </style>
</head>
<body>
    <h1>YOLOv8 tinygrad WebGL</h1>
    <div class="video-container">
        <video id="video" muted autoplay playsinline></video>
        <canvas id="canvas"></canvas>
    </div>
    
    <script>
        // Inference function; stays null until detectObjectsOnFrame loads it on first use.
        let net = null;

        // Visible elements: the live camera <video> and the overlay <canvas> stacked on top of it.
        const video = document.getElementById('video');
        const canvas = document.getElementById('canvas');
        const context = canvas.getContext('2d');

        // Detached 640x640 canvas; processFrame draws each video frame here before pixels are read.
        const offscreenCanvas = document.createElement('canvas');
        offscreenCanvas.width = 640;
        offscreenCanvas.height = 640;
        const offscreenContext = offscreenCanvas.getContext('2d');

        if (navigator.mediaDevices && navigator.mediaDevices.getUserMedia) {
            navigator.mediaDevices.getUserMedia({ audio: false, video: true })
                .then(function (stream) {
                    video.srcObject = stream;
                    video.onloadedmetadata = function () {
                        // Match the overlay's drawing buffer to the rendered size of the video element.
                        canvas.width = video.clientWidth;
                        canvas.height = video.clientHeight;
                    };
                })
                .catch(function (err) {
                    // Permission denied, no camera, insecure origin, ...: report instead of
                    // leaving an unhandled promise rejection.
                    console.error("Unable to access the camera:", err);
                });
        } else {
            console.error("navigator.mediaDevices.getUserMedia is not available in this browser");
        }

        /**
         * One iteration of the render loop: copies the current video frame onto the
         * 640x640 off-screen canvas, runs detection on it, draws the resulting boxes
         * on the overlay, and schedules itself for the next animation frame.
         *
         * While the video has no frame data yet, the iteration only reschedules itself.
         * If detection or drawing throws, the error is logged and the loop stops.
         */
        async function processFrame() {
            // Nothing to detect until the video element holds at least one decoded frame.
            if (video.readyState < HTMLMediaElement.HAVE_CURRENT_DATA) {
                requestAnimationFrame(processFrame);
                return;
            }

            try {
                offscreenContext.drawImage(video, 0, 0, 640, 640);
                const boxes = await detectObjectsOnFrame(offscreenContext);
                drawBoxes(offscreenCanvas, boxes);
            } catch (err) {
                // requestAnimationFrame ignores the promise returned by this async callback,
                // so failures must be handled here or they become unhandled rejections.
                console.error("Object detection loop stopped:", err);
                return;
            }

            requestAnimationFrame(processFrame);
        }

        requestAnimationFrame(processFrame);

        /**
         * Clears the first <canvas> in the document and draws every detection on it:
         * a coloured rectangle, a filled label background above its top-left corner,
         * and the label text in black.
         *
         * Box coordinates are taken to be in a 640x640 space and are stretched to the
         * canvas' current width and height.
         *
         * @param {*} offscreenCanvas - Not used by this function.
         * @param {Array} boxes - Entries of the form [x1, y1, x2, y2, label, ...].
         */
        function drawBoxes(offscreenCanvas, boxes) {
            const overlay = document.querySelector("canvas");
            const ctx = overlay.getContext("2d");
            const { width: overlayWidth, height: overlayHeight } = overlay;

            ctx.clearRect(0, 0, overlayWidth, overlayHeight);
            ctx.lineWidth = 3;
            ctx.font = "20px serif";

            const scaleX = overlayWidth / 640;
            const scaleY = overlayHeight / 640;

            for (const [x1, y1, x2, y2, label] of boxes) {
                // One colour per class, looked up through the label's position in yolo_classes.
                const color = classColors[yolo_classes.indexOf(label)];
                const labelWidth = ctx.measureText(label).width;

                const left = x1 * scaleX;
                const upper = y1 * scaleY;
                const right = x2 * scaleX;
                const lower = y2 * scaleY;

                ctx.strokeStyle = color;
                ctx.fillStyle = color;
                ctx.strokeRect(left, upper, right - left, lower - upper);

                // 25px-high label background sitting directly above the box.
                ctx.fillRect(left, upper - 25, labelWidth + 10, 25);
                ctx.fillStyle = "#000000";
                ctx.fillText(label, left, upper - 7);
            }
        }

        /**
         * Runs the full detection pipeline on the given 2D context: loads the network
         * on first use, converts the pixels to the network input, runs the network and
         * post-processes its output. The duration of each stage is logged to the console.
         *
         * @param {CanvasRenderingContext2D} offscreenContext - Context holding the frame.
         * @returns {Promise<Array>} Whatever processOutput returns for this frame.
         */
        async function detectObjectsOnFrame(offscreenContext) {
            if (!net) {
                net = await loadNet();
            }

            const preprocessStart = performance.now();
            const [input, img_width, img_height] = await prepareInput(offscreenContext);
            console.log(`Preprocess took: ${performance.now() - preprocessStart} ms`);

            const inferenceStart = performance.now();
            const output = net(new Float32Array(input));
            console.log(`Inference took: ${performance.now() - inferenceStart} ms`);

            const postprocessStart = performance.now();
            const detections = processOutput(output, img_width, img_height);
            console.log(`Postprocess took: ${performance.now() - postprocessStart} ms`);

            return detections;
        }

        /**
         * Reads the 640x640 pixel region of the given context and converts it from
         * interleaved RGBA bytes into a planar array: all red values first, then all
         * green, then all blue, each divided by 255. The alpha channel is dropped.
         *
         * @param {CanvasRenderingContext2D} offscreenContext - Context to read pixels from.
         * @returns {Promise<[number[], number, number]>} [input, img_width, img_height];
         *          the two sizes are always 640.
         */
        async function prepareInput(offscreenContext) {
            const [img_width, img_height] = [640, 640];
            const pixels = offscreenContext.getImageData(0, 0, 640, 640).data;

            // Fill one channel plane after another into a single array, rather than
            // building three temporary arrays and copying them with spread.
            const input = [];
            for (let channel = 0; channel < 3; channel++) {
                for (let index = channel; index < pixels.length; index += 4) {
                    input.push(pixels[index] / 255.0);
                }
            }

            return [input, img_width, img_height];
        }
        
        /**
         * Downloads ./net.safetensors, creates a WebGL2 context on a detached canvas
         * and hands both to setupNet.
         *
         * @returns {Promise<*>} The value produced by setupNet, or null when the download,
         *          the WebGL2 context creation or setupNet fails (the error is logged).
         */
        const loadNet = async () => {
            try {
                const response = await fetch("./net.safetensors");
                // fetch only rejects on network errors; an HTTP error page must not be
                // passed on as if it were model weights.
                if (!response.ok) {
                    throw new Error(`Failed to fetch ./net.safetensors: HTTP ${response.status}`);
                }
                const safetensor = new Uint8Array(await response.arrayBuffer());

                const gl = document.createElement("canvas").getContext("webgl2");
                if (!gl) {
                    throw new Error("WebGL2 is not available in this browser");
                }

                // Awaiting here keeps a rejected promise from setupNet inside this try/catch.
                return await setupNet(gl, safetensor);
            } catch (e) {
                console.error(e);
                return null;
            }
        }

        /**
         * Converts the raw network output into a list of detections.
         *
         * The output is read as 84 consecutive rows of 8400 values each: rows 0-3 hold
         * the box centre x, centre y, width and height, and rows 4-83 hold one score
         * per class. For every one of the 8400 candidates the best-scoring class is
         * taken; candidates scoring below 0.25 are dropped. The remaining boxes are
         * rescaled from a 640x640 space to img_width x img_height and filtered with
         * non-maximum suppression (IoU threshold 0.7, applied regardless of class).
         *
         * @param {ArrayLike<number>} output - Flat network output.
         * @param {number} img_width - Target width for the box coordinates.
         * @param {number} img_height - Target height for the box coordinates.
         * @returns {Array} Detections as [x1, y1, x2, y2, label, prob], highest score first.
         */
        function processOutput(output, img_width, img_height) {
            const CANDIDATES = 8400;
            const CLASSES = 80;
            const MIN_SCORE = 0.25;
            const MAX_OVERLAP = 0.7;

            let boxes = [];
            for (let index = 0; index < CANDIDATES; index++) {
                // Plain loop instead of building ~81 throw-away arrays per candidate.
                let class_id = 0;
                let prob = 0;
                for (let col = 0; col < CLASSES; col++) {
                    const score = output[CANDIDATES * (col + 4) + index];
                    if (score > prob) {
                        class_id = col;
                        prob = score;
                    }
                }
                if (prob < MIN_SCORE) {
                    continue;
                }

                const label = yolo_classes[class_id];
                const xc = output[index];
                const yc = output[CANDIDATES + index];
                const w = output[2 * CANDIDATES + index];
                const h = output[3 * CANDIDATES + index];
                const x1 = (xc - w / 2) / 640 * img_width;
                const y1 = (yc - h / 2) / 640 * img_height;
                const x2 = (xc + w / 2) / 640 * img_width;
                const y2 = (yc + h / 2) / 640 * img_height;
                boxes.push([x1, y1, x2, y2, label, prob]);
            }

            boxes.sort((box1, box2) => box2[5] - box1[5]);

            const result = [];
            while (boxes.length > 0) {
                // Remove the head explicitly so the loop always makes progress, rather
                // than relying on the head's IoU with itself being above the threshold.
                const best = boxes.shift();
                result.push(best);
                boxes = boxes.filter((box) => iou(best, box) < MAX_OVERLAP);
            }
            return result;
        }

        /**
         * Intersection-over-union of two boxes: intersection(box1, box2) divided by
         * union(box1, box2).
         *
         * @param {Array<number>} box1 - Box as [x1, y1, x2, y2, ...].
         * @param {Array<number>} box2 - Box as [x1, y1, x2, y2, ...].
         * @returns {number} The ratio of the two areas.
         */
        function iou(box1, box2) {
            const overlapArea = intersection(box1, box2);
            const combinedArea = union(box1, box2);
            return overlapArea / combinedArea;
        }

        /**
         * Area covered by either box: the sum of both box areas minus the value of
         * intersection(box1, box2).
         *
         * @param {Array<number>} box1 - Box as [x1, y1, x2, y2, ...].
         * @param {Array<number>} box2 - Box as [x1, y1, x2, y2, ...].
         * @returns {number} The combined area.
         */
        function union(box1, box2) {
            const areaOf = ([left, top, right, bottom]) => (right - left) * (bottom - top);
            return areaOf(box1) + areaOf(box2) - intersection(box1, box2);
        }

        /**
         * Area of the overlap between two axis-aligned boxes.
         *
         * @param {Array<number>} box1 - Box as [x1, y1, x2, y2, ...].
         * @param {Array<number>} box2 - Box as [x1, y1, x2, y2, ...].
         * @returns {number} The overlap area, or 0 when the boxes do not overlap.
         */
        function intersection(box1, box2) {
            const [box1_x1, box1_y1, box1_x2, box1_y2] = box1;
            const [box2_x1, box2_y1, box2_x2, box2_y2] = box2;
            const overlapWidth = Math.min(box1_x2, box2_x2) - Math.max(box1_x1, box2_x1);
            const overlapHeight = Math.min(box1_y2, box2_y2) - Math.max(box1_y1, box2_y1);

            // A non-positive extent on either axis means the boxes are disjoint. Without
            // this guard two negative extents would multiply into a positive "overlap".
            if (overlapWidth <= 0 || overlapHeight <= 0) {
                return 0;
            }
            return overlapWidth * overlapHeight;
        }

        // Class labels, indexed by class id. processOutput maps the winning class index to
        // its label through this list, and drawBoxes uses a label's position in it to pick
        // the matching entry of classColors.
        const yolo_classes = [
            'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
            'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
            'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase',
            'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
            'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
            'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
            'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
            'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
        ];

        /**
         * Builds a palette of fully saturated HSL colours whose hues are spread evenly
         * over the 0-360 degree range, starting at hue 0.
         *
         * @param {number} numColors - Number of colours wanted.
         * @returns {string[]} CSS colour strings of the form "hsl(<hue>, 100%, 50%)";
         *          empty when numColors is zero or negative.
         */
        function generateColors(numColors) {
            const colors = [];
            // Derive each hue from its index instead of accumulating a floating-point
            // step: the result then has exactly numColors entries, and a zero or
            // negative count yields an empty palette instead of a stray colour or an
            // endless loop.
            for (let i = 0; i < numColors; i++) {
                const hue = (i * 360) / numColors;
                colors.push(`hsl(${hue}, 100%, 50%)`);
            }
            return colors;
        }

        // One colour per entry of yolo_classes, at the same index; drawBoxes reads it to colour each box.
        const classColors = generateColors(yolo_classes.length);
    </script>
</body>
</html>
